import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt

# Monk Skin Tone (MST) swatch colours, one per tone 1..10 in order; the hex
# values run roughly from lightest to darkest.
MST = """
#f6ede4 #f3e7db #f7ead0 #eadaba #d7bd96
#a07e56 #825c43 #604134 #3a312a #292420
""".split()


# Global matplotlib settings: embed fonts as Type 42 (TrueType) in PDF/PS
# output, and draw ticks/gridlines beneath the plotted artists.
matplotlib.rcParams.update(
    {
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "axes.axisbelow": True,
    }
)
# 15-level colorblind-friendly palette, taken from
# https://jacksonlab.agronomy.wisc.edu/2016/05/23/15-level-colorblind-friendly-palette/
colors_15 = """
#000000 #004949 #009292 #ff6db6 #ffb6db
#490092 #006ddb #b66dff #6db6ff #b6dbff
#920000 #924900 #db6d00 #24ff24 #ffff6d
""".split()

def pre_plot():
    """Prepare the current axes before anything is drawn on them.

    Hides the bottom, left and right tick marks (major and minor), applies
    ``tight_layout`` to the current figure, and enables a dotted grid.
    """
    axes = plt.gca()
    axes.tick_params(which="both", bottom=False, left=False, right=False)
    plt.tight_layout()
    plt.grid(linestyle=":")


def post_plot():
    """Hide all four spines (the border lines) of the current axes."""
    spines = plt.gca().spines
    for side in ["top", "right", "bottom", "left"]:
        spines[side].set_visible(False)


def plot_se_comp(
    stats,
    add_st_wt,
    fp_out="plots/results_se_comparison",
):
    """Draw search-engine comparison bar charts and save them as a PDF.

    One panel is drawn per attribute (perceived gender, Monk Skin Tone,
    perceived age). In each panel x is the attribute level, y is the value
    taken from the matching column of ``stats``, and hue is the search
    engine. Bars are recoloured per level, and the second engine is hatched.

    Args:
        stats: DataFrame with a ``sengine`` column (engine names, title-cased
            for display) plus one column per attribute level:
            ``feminine_presenting`` / ``masculine_presenting`` /
            ``non_binary_presenting``, the integers 1-10 (skin tone), and the
            age-bin labels ``"0-10"`` ... ``"80-90"``.
        add_st_wt: When truthy, only the Monk Skin Tone panel is drawn and
            ``_sensitivity`` is appended to the file name.
        fp_out: Output path without extension; ``.pdf`` is appended.
    """
    suffix = "_sensitivity" if add_st_wt else ""
    fp_out = f"{fp_out}{suffix}.pdf"

    age_bins = [
        "0-10",
        "10-20",
        "20-30",
        "30-40",
        "40-50",
        "50-60",
        "60-70",
        "70-80",
        "80-90",
    ]
    level_columns = {
        "Perceived Gender": [
            "feminine_presenting",
            "masculine_presenting",
            "non_binary_presenting",
        ],
        "Monk Skin Tone": list(range(1, 11)),
        "Perceived Age": age_bins,
    }
    # One grey per age bin, in the reverse of seaborn's sampling order.
    greys = list(sns.color_palette("gray", n_colors=len(age_bins)).as_hex())
    greys.reverse()
    pastel = sns.color_palette("pastel").as_hex()
    bar_colors = {
        "Perceived Gender": [pastel[idx] for idx in (6, 0, 4)],
        "Monk Skin Tone": MST,
        "Perceived Age": greys,
    }
    # Hatch per engine: none for the first, "//" for the second.
    hatches = ("", "//")

    fig = plt.figure(figsize=(12, 4))
    panels = ["Monk Skin Tone"] if add_st_wt else list(level_columns)
    # The 1 x 3 grid is kept even when only one panel is drawn (presumably so
    # that panel has the same size as in the full figure).
    n_slots = len(level_columns)
    for slot, attr in enumerate(panels, start=1):
        fig.add_subplot(1, n_slots, slot)
        long_df = stats.melt(id_vars="sengine", value_vars=level_columns[attr])
        long_df["sengine"] = long_df["sengine"].str.title()

        pre_plot()
        ax = sns.barplot(
            data=long_df,
            x="variable",
            y="value",
            hue="sengine",
            color="lightgrey",
            alpha=0.99,
        )

        # One bar container per engine: recolour each bar by attribute level
        # and apply the engine's hatch, mirroring it on the legend patch.
        # zip() stops after two engines; any further engine keeps its default
        # light-grey, unhatched bars.
        legend_patches = ax.legend_.legend_handles
        for container, pattern, patch in zip(ax.containers, hatches, legend_patches):
            for bar, shade in zip(container, bar_colors[attr]):
                bar.set_facecolor(shade)
                bar.set_hatch(pattern)
            patch.set_hatch(pattern)

        if attr == "Perceived Gender":
            ax.set_xticklabels(
                [
                    "Feminine\nPresenting",
                    "Masculine\nPresenting",
                    "Non-binary\nPresenting",
                ]
            )
        elif attr == "Perceived Age":
            ax.tick_params(axis="x", rotation=60)

        # Only the skin-tone panel keeps its legend; only the gender panel
        # gets a y-axis label.
        if attr == "Monk Skin Tone":
            ax.get_legend().set_title("Search Engine")
        else:
            ax.get_legend().remove()
        ylabel = "Fraction of Images" if attr == "Perceived Gender" else ""
        ax.set(xlabel=attr, ylabel=ylabel)

        post_plot()

    plt.savefig(fp_out, bbox_inches="tight")


def plot_intersectional(stats, cat, fp_out):
    """Plot perceived-gender fractions broken down by a second attribute.

    Draws a grouped bar chart with ``cat`` on the x-axis, column ``0`` on the
    y-axis ("Fraction of Images") and one hue per perceived gender. Rows whose
    gender is ``"unsure"`` are dropped. The figure is saved to ``fp_out``.

    Args:
        stats: Indexed data whose ``reset_index()`` yields ``gender``, ``cat``
            and ``0`` columns (presumably an unnamed Series indexed by gender
            and ``cat``; confirm against callers).
        cat: ``"skin_tone"`` (fixed y-range 0-0.16, x label "Monk Skin Tone")
            or ``"age_bin"`` (rotated x tick labels). Any other value is
            labelled "Perceived Age".
        fp_out: Output path passed to ``plt.savefig``.
    """
    stats = stats.reset_index()
    # Copy the filtered rows explicitly so the column assignment below acts on
    # an independent frame rather than a possible view; without the copy,
    # pandas emits SettingWithCopyWarning.
    stats = stats[stats.gender != "unsure"].copy()
    stats["gender"] = stats.gender.replace(
        {
            "feminine_presenting": "Feminine Presenting",
            "masculine_presenting": "Masculine Presenting",
            "non_binary_presenting": "Non-binary Presenting",
        }
    )

    plt.figure(figsize=(5, 3))
    pre_plot()
    g_colors = list(sns.color_palette("pastel").as_hex())
    g = sns.barplot(
        data=stats,
        x=cat,
        y=0,
        hue="gender",
        palette=[g_colors[6], g_colors[0], g_colors[4]],
        saturation=1,
    )
    g.get_legend().set_title("Perceived Gender")
    if cat == "skin_tone":
        g.set_ylim(0, 0.16)
    elif cat == "age_bin":
        g.tick_params(axis="x", rotation=60)
    xlabel = "Monk Skin Tone" if cat == "skin_tone" else "Perceived Age"
    g.set(xlabel=xlabel, ylabel="Fraction of Images")
    post_plot()
    plt.savefig(fp_out, bbox_inches="tight")


def plot_cat_comps(stats):
    """Print and plot the per-category comparison for each outcome.

    Prints every value in ``stats``. Then assigns one ``colors_15`` colour to
    each category, in the row order of the ``feminine_presenting`` table, and
    calls ``plot_cat_comp`` for three outcomes: the feminine-presenting
    fraction, mean Monk Skin Tone and mean age.

    Args:
        stats: Mapping from outcome name to a per-category DataFrame with a
            ``category`` column (presumably a dict; it must contain the keys
            ``feminine_presenting``, ``skin_tone`` and ``age``).
    """
    for _, table in stats.items():
        print(table)

    categories = stats["feminine_presenting"].category
    # zip() truncates: only the first 15 categories receive a colour.
    palette = dict(zip(categories, colors_15))
    axis_labels = (
        ("feminine_presenting", "Fraction Feminine Presenting"),
        ("skin_tone", "Mean Monk Skin Tone"),
        ("age", "Mean Age"),
    )
    for outcome, label in axis_labels:
        plot_cat_comp(outcome, stats[outcome], label, palette)


def plot_cat_comp(
    outcome,
    stats,
    label,
    cmap,
    fp_out="plots/results_cat_comparison_img_wt",
):
    """Draw one outcome's per-category estimate with its [lb, ub] interval.

    Produces a horizontal bar chart (one bar per category, in the row order
    of ``stats``) with a dashed red vertical reference line at a fixed value
    for the outcome. The figure is saved to ``f"{fp_out}_{outcome}.pdf"``.

    Args:
        outcome: ``"feminine_presenting"`` (reference line at 0.504), ``"age"``
            (reference line at 38.9), or anything else, which is treated as
            skin tone (x-range 1-6, reference line at 5.5).
        stats: DataFrame with ``category``, ``adj_category``, ``est``, ``lb``
            and ``ub`` columns, presumably one row per category with
            ``lb <= est <= ub``.
        label: x-axis label.
        cmap: Mapping from category to bar colour.
        fp_out: Output path prefix.
    """
    plt.figure(figsize=(5, 3.5))
    pre_plot()
    # Long format: three rows (est, lb, ub) per category.
    plot_data = stats.melt(id_vars="category", value_vars=["est", "lb", "ub"])
    # With seaborn's defaults, the bar would be the mean of (est, lb, ub),
    # which is off for asymmetric intervals. The whisker would be a
    # bootstrapped 95% CI of those three points: random between runs and
    # narrower than [lb, ub].
    # The median of the three values is est (given lb <= est <= ub), and the
    # 100% percentile interval spans exactly [min, max] = [lb, ub].
    g = sns.barplot(
        data=plot_data,
        x="value",
        y="category",
        order=stats.category,
        palette=cmap,
        estimator="median",
        errorbar=("pi", 100),
    )
    if outcome == "feminine_presenting":
        plt.axvline(0.504, ls="dashed", color="red")
    elif outcome == "age":
        plt.axvline(38.9, ls="dashed", color="red")
    else:
        plt.xlim(1, 6)
        plt.axvline(5.5, ls="dashed", color="red")
    g.set(xlabel=label, ylabel="Category")
    # Tick labels follow the same row order as ``order=stats.category``.
    g.set_yticklabels(stats.adj_category)
    post_plot()
    plt.savefig(f"{fp_out}_{outcome}.pdf", bbox_inches="tight")
